We want to investigate various convolutional neural networks for distinguishing dachshunds from other dog breeds. Our ultimate goal is to build a lightweight model and deploy it to Heroku as a Flask app, with a simple frontend where users can upload their own images and see the classification result. You can try out the finished product at https://doxie-detector.herokuapp.com (note: the initial boot-up can take a while on Heroku's free tier). This notebook is a continuation of linear_models.ipynb, where we applied classical linear methods to obtain a baseline accuracy for our problem: using principal component analysis and support vector machines we achieved 61.2% validation accuracy. We now want to see how much we can improve on this with more sophisticated models.
The plan for the rest of the notebook is as follows. We'll first briefly overview the dataset and prepare it for usage in Keras models. Then, we'll train a simple CNN from scratch, which gives us a modest improvement in terms of accuracy. To further improve on this we perform transfer learning on existing models, namely ResNet and EfficientNet. This allows us to obtain a decent accuracy of ~93.6% without much effort. We then investigate the finished model with the help of saliency maps and class activation maps. As an application of the saliency maps, we also develop a novel simple algorithm for object localisation, which does not use any additional input from the network. Finally, we'll briefly investigate two popular methods for general ML model explainability: SHAP and Lime.
Our dataset consists of 2129 pictures of dachshunds and dogs of other breeds, spread roughly evenly across the two classes. All of the pictures were scraped from various public sources such as Dog API and reddit. The only constraint on the selected images was that each should contain exactly one dog (at least partially visible). No preprocessing or pruning of any kind was done, so the dataset includes plenty of out-of-focus photos, poorly cropped or centred pictures, and photos where only a small part of the dog is visible, to name just a few. It is possible, for example, that a sample image doesn't show the dog's face at all, which means that our model will have to learn to be very versatile. Moreover, some dogs appear in multiple pictures in our dataset, which we expect will add extra noise to the validation accuracy.
We split these images into a train and a validation set with an 80-20 split and resize each of them to 224x224 pixels. We do not use a proper test set, but instead investigate our final model with a few select images of family dogs (which are not present in the train or validation datasets at all).
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
import os, glob
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import PIL, PIL.Image
import pickle
from IPython.core.display import HTML
import shap
import lime
from lime import lime_image
from skimage.segmentation import mark_boundaries
display(HTML("<style>div.jp-RenderedText {overflow-y: auto; max-height: 200px;}</style>"))
DATA_DIR = 'data'
dachshunds = ['data/dachshund/' + i for i in os.listdir('data/dachshund')
              if os.path.splitext(i)[1][1:] in ["jpeg", "jpg", "png"]]
PIL.Image.open(dachshunds[5]).resize((224, 224))
We're training our models locally with only one NVIDIA GTX 1660 Ti 6GB. With the following cell we can limit TensorFlow's memory usage to 4GB in order to keep the system operational during training.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Virtual devices must be set before GPUs have been initialized
        print(e)
1 Physical GPUs, 1 Logical GPUs
AUTO = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
CHANNELS = 3
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)
os.environ['TF_DETERMINISTIC_OPS'] = '1'
class_names = ['other', 'dachshund']
train_ds_ = tf.keras.preprocessing.image_dataset_from_directory(
    DATA_DIR,
    validation_split=0.2,
    subset='training',
    seed=SEED,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_names=class_names
)
val_ds_ = tf.keras.preprocessing.image_dataset_from_directory(
    DATA_DIR,
    validation_split=0.2,
    subset='validation',
    seed=SEED,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_names=class_names
)
test_ds_ = tf.keras.preprocessing.image_dataset_from_directory(
    'test',
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    class_names=class_names
)
Found 2129 files belonging to 2 classes. Using 1704 files for training. Found 2129 files belonging to 2 classes. Using 425 files for validation. Found 8 files belonging to 2 classes.
cast_img = lambda x, y: (tf.image.convert_image_dtype(x, dtype=tf.float32), y)
train_ds = (train_ds_
            .map(cast_img)
            .cache()
            .prefetch(buffer_size=AUTO))
val_ds = (val_ds_
          .map(cast_img)
          .cache()
          .prefetch(buffer_size=AUTO))
test_ds = (test_ds_
           .map(cast_img)
           .cache()
           .prefetch(buffer_size=AUTO))
def plot_scores(hist, metric='accuracy', finetune=None, clamp_loss=None, smooth=0):
    fig, axs = plt.subplots(1, 2, figsize=(14, 4))
    x = range(1, len(hist['loss']) + 1)
    axs[0].plot(x, hist['loss'], label='loss')
    axs[0].plot(x, hist['val_loss'], label='val_loss')
    if clamp_loss:
        axs[0].set_ylim(*clamp_loss)
    axs[1].plot(x, hist[metric], label=metric)
    axs[1].plot(x, hist['val_' + metric], label='val_' + metric)
    if smooth > 0:
        # Moving average of the validation metric over `smooth` epochs
        axs[1].plot(x[smooth:],
                    np.convolve(hist['val_' + metric], np.ones(smooth) / smooth,
                                mode='full')[smooth:len(x)],
                    label='val_' + metric + '_smooth', c='#7f7f7f', linestyle='--')
    for ax in axs:
        if finetune:
            ax.plot([finetune, finetune], ax.get_ylim(), label='start finetuning',
                    scaley=False)
        ax.legend()
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.set_xlabel('Epoch')
    plt.show()
We first train a simple sequential CNN architecture from scratch. It is based on the VGG16 network (see configuration A in Table 1 of the VGG paper), but we reduce the size of the fully connected layers since our dataset is small and we only want to perform binary classification. To alleviate problems with high variance due to the small training set, we add regularisation in the form of some simple image augmentation steps and batch normalisation layers after each convolutional layer. As usual, we also apply 0.5 dropout to the fully connected top layers.
We use the Adam optimiser with a learning rate of 0.001 and epsilon of 0.001 for training. After ~40 epochs we obtain 70.3% accuracy, roughly a nine-percentage-point improvement over our baseline, at the cost of a much more complicated model.
data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip(mode='horizontal'),
    layers.experimental.preprocessing.RandomRotation(0.2),
    layers.experimental.preprocessing.RandomTranslation(0.1, 0.1),
    layers.experimental.preprocessing.RandomZoom(0.2)
])
model_simple = tf.keras.Sequential([
    layers.InputLayer(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS)),
    data_augmentation,
    layers.experimental.preprocessing.RandomContrast(0.2),
    layers.experimental.preprocessing.Rescaling(scale=1./255),
    # Block 1
    layers.Conv2D(64, 3, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPool2D(),
    # Block 2
    layers.Conv2D(128, 3, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPool2D(),
    # Block 3
    layers.Conv2D(256, 3, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(256, 3, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPool2D(),
    # Block 4
    layers.Conv2D(512, 3, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(512, 3, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPool2D(),
    # Block 5
    layers.Conv2D(512, 3, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(512, 3, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPool2D(),
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid'),
])
print(f"{model_simple.count_params()/1e6:.1f} M parameters")
22.3 M parameters
base_lr = 0.001
early_stop = tf.keras.callbacks.EarlyStopping(patience=15, restore_best_weights=True)
callbacks = [early_stop]
model_simple.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=base_lr, epsilon=1e-3),
    loss='binary_crossentropy',
    metrics=['binary_accuracy']
)
history_simple = model_simple.fit(
    train_ds,
    validation_data=val_ds,
    epochs=200,
    callbacks=callbacks,
    verbose=True
)
Epoch 1/200 54/54 [==============================] - 72s 1s/step - loss: 2.6941 - binary_accuracy: 0.5306 - val_loss: 4.7720 - val_binary_accuracy: 0.4847 Epoch 2/200 54/54 [==============================] - 30s 559ms/step - loss: 3.0858 - binary_accuracy: 0.5237 - val_loss: 0.8158 - val_binary_accuracy: 0.5224 Epoch 3/200 54/54 [==============================] - 29s 545ms/step - loss: 2.2368 - binary_accuracy: 0.5348 - val_loss: 0.7357 - val_binary_accuracy: 0.5294 Epoch 4/200 54/54 [==============================] - 30s 553ms/step - loss: 1.5296 - binary_accuracy: 0.5262 - val_loss: 0.8043 - val_binary_accuracy: 0.4871 Epoch 5/200 54/54 [==============================] - 30s 557ms/step - loss: 1.4557 - binary_accuracy: 0.4861 - val_loss: 0.7116 - val_binary_accuracy: 0.4965 Epoch 6/200 54/54 [==============================] - 29s 537ms/step - loss: 1.0744 - binary_accuracy: 0.5536 - val_loss: 0.7192 - val_binary_accuracy: 0.5153 Epoch 7/200 54/54 [==============================] - 31s 578ms/step - loss: 1.0820 - binary_accuracy: 0.5592 - val_loss: 0.7314 - val_binary_accuracy: 0.4612 Epoch 8/200 54/54 [==============================] - 31s 582ms/step - loss: 0.9724 - binary_accuracy: 0.5342 - val_loss: 0.6741 - val_binary_accuracy: 0.5835 Epoch 9/200 54/54 [==============================] - 27s 498ms/step - loss: 0.8364 - binary_accuracy: 0.5576 - val_loss: 0.6797 - val_binary_accuracy: 0.6024 Epoch 10/200 54/54 [==============================] - 27s 494ms/step - loss: 0.7633 - binary_accuracy: 0.5621 - val_loss: 0.6603 - val_binary_accuracy: 0.6400 Epoch 11/200 54/54 [==============================] - 27s 494ms/step - loss: 0.7768 - binary_accuracy: 0.5507 - val_loss: 0.6496 - val_binary_accuracy: 0.6259 Epoch 12/200 54/54 [==============================] - 27s 493ms/step - loss: 0.7309 - binary_accuracy: 0.6013 - val_loss: 0.6607 - val_binary_accuracy: 0.6282 Epoch 13/200 54/54 [==============================] - 27s 494ms/step - loss: 0.6989 - binary_accuracy: 
0.5970 - val_loss: 0.6554 - val_binary_accuracy: 0.6635 Epoch 14/200 54/54 [==============================] - 27s 494ms/step - loss: 0.7139 - binary_accuracy: 0.5939 - val_loss: 0.6491 - val_binary_accuracy: 0.6588 Epoch 15/200 54/54 [==============================] - 28s 522ms/step - loss: 0.6687 - binary_accuracy: 0.6178 - val_loss: 0.6749 - val_binary_accuracy: 0.5976 Epoch 16/200 54/54 [==============================] - 28s 516ms/step - loss: 0.6846 - binary_accuracy: 0.5999 - val_loss: 0.6467 - val_binary_accuracy: 0.6376 Epoch 17/200 54/54 [==============================] - 28s 517ms/step - loss: 0.6743 - binary_accuracy: 0.6159 - val_loss: 0.6491 - val_binary_accuracy: 0.6424 Epoch 18/200 54/54 [==============================] - 28s 518ms/step - loss: 0.6527 - binary_accuracy: 0.6206 - val_loss: 0.6342 - val_binary_accuracy: 0.6612 Epoch 19/200 54/54 [==============================] - 28s 520ms/step - loss: 0.6542 - binary_accuracy: 0.6302 - val_loss: 0.6367 - val_binary_accuracy: 0.6635 Epoch 20/200 54/54 [==============================] - 28s 520ms/step - loss: 0.6372 - binary_accuracy: 0.6380 - val_loss: 0.6517 - val_binary_accuracy: 0.6141 Epoch 21/200 54/54 [==============================] - 28s 523ms/step - loss: 0.6395 - binary_accuracy: 0.6389 - val_loss: 0.6417 - val_binary_accuracy: 0.6376 Epoch 22/200 54/54 [==============================] - 28s 520ms/step - loss: 0.6327 - binary_accuracy: 0.6363 - val_loss: 0.6450 - val_binary_accuracy: 0.6259 Epoch 23/200 54/54 [==============================] - 29s 533ms/step - loss: 0.6528 - binary_accuracy: 0.6328 - val_loss: 0.6134 - val_binary_accuracy: 0.6753 Epoch 24/200 54/54 [==============================] - 28s 517ms/step - loss: 0.6324 - binary_accuracy: 0.6392 - val_loss: 0.6348 - val_binary_accuracy: 0.6800 Epoch 25/200 54/54 [==============================] - 27s 504ms/step - loss: 0.6056 - binary_accuracy: 0.6721 - val_loss: 0.6067 - val_binary_accuracy: 0.6635 Epoch 26/200 54/54 
[==============================] - 27s 495ms/step - loss: 0.6439 - binary_accuracy: 0.6722 - val_loss: 0.6473 - val_binary_accuracy: 0.6306 Epoch 27/200 54/54 [==============================] - 27s 508ms/step - loss: 0.6273 - binary_accuracy: 0.6511 - val_loss: 0.6322 - val_binary_accuracy: 0.6612 Epoch 28/200 54/54 [==============================] - 27s 509ms/step - loss: 0.6244 - binary_accuracy: 0.6694 - val_loss: 0.6338 - val_binary_accuracy: 0.6376 Epoch 29/200 54/54 [==============================] - 27s 501ms/step - loss: 0.6196 - binary_accuracy: 0.6683 - val_loss: 0.6340 - val_binary_accuracy: 0.6447 Epoch 30/200 54/54 [==============================] - 27s 501ms/step - loss: 0.6011 - binary_accuracy: 0.6912 - val_loss: 0.6839 - val_binary_accuracy: 0.5435 Epoch 31/200 54/54 [==============================] - 27s 501ms/step - loss: 0.6086 - binary_accuracy: 0.6748 - val_loss: 0.6292 - val_binary_accuracy: 0.6518 Epoch 32/200 54/54 [==============================] - 27s 504ms/step - loss: 0.6061 - binary_accuracy: 0.6672 - val_loss: 0.5984 - val_binary_accuracy: 0.6824 Epoch 33/200 54/54 [==============================] - 27s 500ms/step - loss: 0.6040 - binary_accuracy: 0.6798 - val_loss: 0.6459 - val_binary_accuracy: 0.5976 Epoch 34/200 54/54 [==============================] - 27s 509ms/step - loss: 0.6004 - binary_accuracy: 0.6876 - val_loss: 0.6383 - val_binary_accuracy: 0.6212 Epoch 35/200 54/54 [==============================] - 27s 503ms/step - loss: 0.5984 - binary_accuracy: 0.6921 - val_loss: 0.6231 - val_binary_accuracy: 0.6494 Epoch 36/200 54/54 [==============================] - 27s 505ms/step - loss: 0.6108 - binary_accuracy: 0.6807 - val_loss: 0.6349 - val_binary_accuracy: 0.6282 Epoch 37/200 54/54 [==============================] - 27s 502ms/step - loss: 0.5770 - binary_accuracy: 0.6768 - val_loss: 0.6160 - val_binary_accuracy: 0.6518 Epoch 38/200 54/54 [==============================] - 27s 498ms/step - loss: 0.5988 - binary_accuracy: 0.6744 
- val_loss: 0.6411 - val_binary_accuracy: 0.6306 Epoch 39/200 54/54 [==============================] - 27s 496ms/step - loss: 0.5956 - binary_accuracy: 0.6790 - val_loss: 0.6137 - val_binary_accuracy: 0.6800 Epoch 40/200 54/54 [==============================] - 27s 500ms/step - loss: 0.6146 - binary_accuracy: 0.6798 - val_loss: 0.6805 - val_binary_accuracy: 0.5976 Epoch 41/200 54/54 [==============================] - 27s 500ms/step - loss: 0.5888 - binary_accuracy: 0.6906 - val_loss: 0.5794 - val_binary_accuracy: 0.7035 Epoch 42/200 54/54 [==============================] - 27s 498ms/step - loss: 0.5832 - binary_accuracy: 0.6967 - val_loss: 0.5936 - val_binary_accuracy: 0.7035 Epoch 43/200 54/54 [==============================] - 27s 500ms/step - loss: 0.5809 - binary_accuracy: 0.6955 - val_loss: 0.5951 - val_binary_accuracy: 0.6800 Epoch 44/200 54/54 [==============================] - 27s 505ms/step - loss: 0.6231 - binary_accuracy: 0.6506 - val_loss: 0.5969 - val_binary_accuracy: 0.6988 Epoch 45/200 54/54 [==============================] - 28s 518ms/step - loss: 0.5940 - binary_accuracy: 0.7156 - val_loss: 0.6662 - val_binary_accuracy: 0.5976 Epoch 46/200 54/54 [==============================] - 27s 498ms/step - loss: 0.5850 - binary_accuracy: 0.6976 - val_loss: 0.5930 - val_binary_accuracy: 0.6776 Epoch 47/200 54/54 [==============================] - 27s 505ms/step - loss: 0.5904 - binary_accuracy: 0.6906 - val_loss: 0.6019 - val_binary_accuracy: 0.6729 Epoch 48/200 54/54 [==============================] - 27s 508ms/step - loss: 0.5741 - binary_accuracy: 0.7046 - val_loss: 0.6345 - val_binary_accuracy: 0.6518 Epoch 49/200 54/54 [==============================] - 28s 510ms/step - loss: 0.5838 - binary_accuracy: 0.7189 - val_loss: 0.5963 - val_binary_accuracy: 0.6871 Epoch 50/200 54/54 [==============================] - 27s 504ms/step - loss: 0.5777 - binary_accuracy: 0.6920 - val_loss: 0.6520 - val_binary_accuracy: 0.5882 Epoch 51/200 54/54 
[==============================] - 27s 499ms/step - loss: 0.5861 - binary_accuracy: 0.6753 - val_loss: 0.6010 - val_binary_accuracy: 0.6518 Epoch 52/200 54/54 [==============================] - 27s 506ms/step - loss: 0.5774 - binary_accuracy: 0.6933 - val_loss: 0.6093 - val_binary_accuracy: 0.6682 Epoch 53/200 54/54 [==============================] - 28s 509ms/step - loss: 0.5650 - binary_accuracy: 0.6812 - val_loss: 0.6243 - val_binary_accuracy: 0.6494 Epoch 54/200 54/54 [==============================] - 27s 508ms/step - loss: 0.5512 - binary_accuracy: 0.7096 - val_loss: 0.5947 - val_binary_accuracy: 0.6706 Epoch 55/200 54/54 [==============================] - 28s 511ms/step - loss: 0.5635 - binary_accuracy: 0.6925 - val_loss: 0.6160 - val_binary_accuracy: 0.6518 Epoch 56/200 54/54 [==============================] - 27s 508ms/step - loss: 0.5699 - binary_accuracy: 0.7068 - val_loss: 0.5975 - val_binary_accuracy: 0.6894 Epoch 57/200 54/54 [==============================] - 27s 502ms/step - loss: 0.5519 - binary_accuracy: 0.7105 - val_loss: 0.6159 - val_binary_accuracy: 0.6471 Epoch 58/200 54/54 [==============================] - 27s 497ms/step - loss: 0.5602 - binary_accuracy: 0.7084 - val_loss: 0.7632 - val_binary_accuracy: 0.5835 Epoch 59/200 54/54 [==============================] - 27s 501ms/step - loss: 0.5492 - binary_accuracy: 0.7198 - val_loss: 0.6268 - val_binary_accuracy: 0.6682 Epoch 60/200 54/54 [==============================] - 28s 516ms/step - loss: 0.5720 - binary_accuracy: 0.7171 - val_loss: 0.5991 - val_binary_accuracy: 0.6965 Epoch 61/200 54/54 [==============================] - 28s 514ms/step - loss: 0.5881 - binary_accuracy: 0.7049 - val_loss: 0.6222 - val_binary_accuracy: 0.6235 Epoch 62/200 54/54 [==============================] - 27s 503ms/step - loss: 0.5542 - binary_accuracy: 0.7324 - val_loss: 0.5894 - val_binary_accuracy: 0.6824 Epoch 63/200 54/54 [==============================] - 27s 506ms/step - loss: 0.5297 - binary_accuracy: 0.7498 
- val_loss: 0.6236 - val_binary_accuracy: 0.6400 Epoch 64/200 54/54 [==============================] - 27s 509ms/step - loss: 0.5420 - binary_accuracy: 0.7241 - val_loss: 0.6136 - val_binary_accuracy: 0.6706 Epoch 65/200 54/54 [==============================] - 27s 503ms/step - loss: 0.5589 - binary_accuracy: 0.7171 - val_loss: 0.6284 - val_binary_accuracy: 0.6329 Epoch 66/200 54/54 [==============================] - 28s 514ms/step - loss: 0.5321 - binary_accuracy: 0.7422 - val_loss: 0.6114 - val_binary_accuracy: 0.6612 Epoch 67/200 54/54 [==============================] - 27s 502ms/step - loss: 0.5223 - binary_accuracy: 0.7462 - val_loss: 0.5991 - val_binary_accuracy: 0.6894 Epoch 68/200 54/54 [==============================] - 27s 503ms/step - loss: 0.5357 - binary_accuracy: 0.7405 - val_loss: 0.6492 - val_binary_accuracy: 0.6024 Epoch 69/200 54/54 [==============================] - 27s 500ms/step - loss: 0.5412 - binary_accuracy: 0.7248 - val_loss: 0.6291 - val_binary_accuracy: 0.6282 Epoch 70/200 54/54 [==============================] - 27s 498ms/step - loss: 0.5512 - binary_accuracy: 0.7212 - val_loss: 0.7413 - val_binary_accuracy: 0.5647 Epoch 71/200 54/54 [==============================] - 27s 507ms/step - loss: 0.5411 - binary_accuracy: 0.7419 - val_loss: 0.7175 - val_binary_accuracy: 0.5435
simple_hist = history_simple.history
plot_scores(simple_hist, 'binary_accuracy', clamp_loss=(0.4, 1), smooth=10)
model_simple.evaluate(val_ds, return_dict=True)
14/14 [==============================] - 2s 130ms/step - loss: 0.5794 - binary_accuracy: 0.7035
{'loss': 0.5794425010681152, 'binary_accuracy': 0.703529417514801}
The learning curves show that after about 50 epochs the model starts overfitting, but early stopping recovers the earlier model with less variance. It seems plausible that with some more regularisation (e.g. adding dropout to each convolutional block) we could keep training and obtain further improvements in accuracy. On the other hand, looking at the smoothed validation accuracy, it is quite possible that the final 70% accuracy is partly noise and the true (generalisation) accuracy of the model is somewhat lower. In the next section we will replace the convolutional part of our model with an existing network pretrained on the ImageNet dataset.
In order to improve on our sequential network, we next consider more sophisticated architectures. Also, instead of training from scratch, we use transfer learning with an existing network trained on the massive ImageNet dataset. Our first choice of architecture is ResNet, originally introduced by He et al. (2015) (our particular version of this network is tf.keras.applications.resnet_v2.ResNet50V2, a moderate improvement over the original ResNet50 structure). In this seminal paper the authors introduced so-called skip connections, where the input of a convolutional block bypasses its convolutions entirely and is added directly to the block's output. On the one hand, this has a regularising effect, allowing much deeper networks than e.g. VGG to be trained, since gradients are free to travel (without decaying) through the identity mappings of the skip connections. On the other hand, the deeper layers now have access to finer spatial information (in the form of features from earlier convolutional blocks with higher spatial resolution), allowing them to learn more complex features. All in all, ResNets vastly outperform traditional convolutional networks.
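A minimal sketch of such a skip connection in the functional API (the shapes are illustrative, not the exact ResNet50V2 block):

```python
import tensorflow as tf
from tensorflow.keras import layers

# The input bypasses two convolutions and is added to their output,
# giving gradients an identity path to flow through during backprop.
inputs = layers.Input(shape=(56, 56, 64))
x = layers.Conv2D(64, 3, padding='same', activation='relu')(inputs)
x = layers.Conv2D(64, 3, padding='same')(x)
outputs = layers.Activation('relu')(layers.Add()([inputs, x]))
block = tf.keras.Model(inputs, outputs)
print(block.output_shape)  # (None, 56, 56, 64)
```

Note that the addition requires the block's input and output to have the same shape; in a real ResNet, blocks that change the number of channels use a 1x1 convolution on the shortcut instead of a pure identity.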
Typically, when doing transfer learning, one extracts the trained bottom layers of the pretrained model and discards the final fully connected top layers. A new set of fully connected layers, customised to the particular task, is then attached to this pretrained base and only the weights of the new layers are modified during training. Since ImageNet-trained models learn a very rich set of features, we regularise our fully connected layers with a typical 0.5 dropout.
With this method we obtain 88.4% accuracy in roughly 15 training epochs. This is not surprising, since the ImageNet dataset already contains many pictures of dogs, so the pretrained base is used to detecting them. The downside of our model is that it is rather large, both in terms of disk space and memory requirements at inference time. This unfortunately makes our initial attempt unwieldy for deploying directly to Heroku with Flask. We could circumvent this by setting up a TensorFlow Serving server, for example, but it is typically more desirable to trim the "unnecessary fat" from models we want to deploy to production.
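To see where the bulk of the size comes from, here is a back-of-the-envelope parameter count (a sketch: the ResNet50V2 base figure is an approximation from memory, and the head counts follow from the base producing a 7x7x2048 feature map at 224x224 input):

```python
# The Flatten over the 7x7x2048 feature map feeds a 512-unit Dense layer,
# which alone holds more weights than the whole convolutional base.
flatten_dense = 7 * 7 * 2048 * 512                       # first Dense kernel
head = flatten_dense + 512 + 512 * 512 + 512 + 512 + 1   # kernels + biases
base = 23_564_800                                        # ResNet50V2 conv base (approx.)
total = base + head
print(f"~{total / 1e6:.0f} M parameters, ~{total * 4 / 1e6:.0f} MB as float32")
```

At roughly 300 MB of float32 weights, most of the bulk sits in the dense head rather than the convolutional base, which motivates the global average pooling approach used later.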
pretrained_base = tf.keras.applications.ResNet50V2(include_top=False,
                                                   weights='imagenet')
pretrained_base.trainable = False
preprocess_input_resnet = tf.keras.applications.resnet_v2.preprocess_input
model_resnet = tf.keras.Sequential([
    layers.InputLayer(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS)),
    data_augmentation,
    layers.Lambda(preprocess_input_resnet),
    pretrained_base,
    layers.Flatten(),
    layers.Dense(512),
    layers.Dropout(0.5),
    layers.Dense(512),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid')
])
base_learning_rate = 0.001
model_resnet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['binary_accuracy'],
)
history_resnet = model_resnet.fit(
    train_ds,
    validation_data=val_ds,
    epochs=30,
    verbose=False
)
model_resnet.evaluate(val_ds, return_dict=True)
14/14 [==============================] - 1s 104ms/step - loss: 1.6888 - binary_accuracy: 0.8847
{'loss': 1.688774824142456, 'binary_accuracy': 0.8847059011459351}
resnet_hist = history_resnet.history
plot_scores(resnet_hist, 'binary_accuracy')
We clearly see how the model manages to avoid overfitting with help from the dropout layers and converges to a stable state. Let's see what the model predicts for our test pictures.
predictions = (model_resnet.predict(test_ds) > 0.5).ravel().astype(int)
plt.figure(figsize=(10, 10))
for image, label in test_ds.take(1):
    for i in range(len(predictions)):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(image[i].numpy().astype("uint8"))
        plt.title(f"{class_names[predictions[i]]}")
        plt.axis(False)
All our family dogs were classified correctly! In the next section we decrease the size of our model by using a more recent architecture tailor-made for mobile and edge devices. It contains far fewer weights while offering comparable (or even better) performance. Let's see whether this will be enough for our purposes.
EfficientNet (Tan and Le, 2020) was, at the time of its introduction, somewhat of a culmination of the evolution of multiple convolutional architectures. It did not introduce new ways to arrange layers or route information through the network, per se, but focused on various aspects of optimisation that had fallen by the wayside with innovative new CNN models appearing almost every year. The key idea was that instead of treating existing architectures as fixed, they should be scaled along several dimensions (depth, width and resolution) in a controlled manner based on the task at hand. In the paper the authors applied this technique to many existing networks, drastically reducing the number of parameters while retaining comparable performance. In particular, they introduced a series of optimised networks, EfficientNetB0 to EfficientNetB7, where the complexity grows from B0 to B7 and each network has been tuned with their methodology.
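The compound scaling rule from the paper makes this precise: a single coefficient $\phi$ controls depth $d$, width $w$ and input resolution $r$ simultaneously,

$$d = \alpha^{\phi}, \qquad w = \beta^{\phi}, \qquad r = \gamma^{\phi}, \qquad \text{subject to } \alpha \cdot \beta^{2} \cdot \gamma^{2} \approx 2,$$

with $\alpha = 1.2$, $\beta = 1.1$ and $\gamma = 1.15$ found by grid search, so that each increment of $\phi$ roughly doubles the FLOPS of the network.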
For our problem we pick one of the smaller networks, namely B1. To further reduce the complexity, we replace the fully connected final layers that we used with ResNet by global average pooling. Global average pooling was introduced by Lin et al. as part of their paper Network in Network (2014). This type of layer compresses each feature map from the previous convolutional layer into a single number by computing its mean. The intuition is that the pooling layer retains some of the original spatial structure which would normally be lost after flattening. Moreover, such a layer has no weights to train (we are just taking an average over each individual feature map), which helps against overfitting. GAP also makes the network smaller, because it is usually the final fully connected layers of a CNN that contain the majority of the model's weights.
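To make this concrete, here is a small sketch (the shapes are illustrative) showing that `GlobalAveragePooling2D` is just a mean over the two spatial axes:

```python
import numpy as np
import tensorflow as tf

# Dummy batch of feature maps: 2 images, a 7x7 spatial grid and 1280 filters
# (1280 is the channel count of EfficientNetB1's final feature map).
features = tf.random.uniform((2, 7, 7, 1280))

gap = tf.keras.layers.GlobalAveragePooling2D()(features)
manual = tf.reduce_mean(features, axis=[1, 2])   # mean over height and width

print(gap.shape)                                 # (2, 1280)
print(np.allclose(gap.numpy(), manual.numpy()))  # True
```

The 7x7x1280 feature map collapses to a 1280-vector, so the final Dense(1) layer needs only 1281 parameters instead of the tens of millions a Flatten-plus-Dense head would require.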
def build_model(base_lr=0.001):
    base_model_effnet = tf.keras.applications.efficientnet.EfficientNetB1(
        include_top=False, weights='imagenet')
    base_model_effnet.trainable = False
    preprocess_input_effnet = tf.keras.applications.efficientnet.preprocess_input
    model = tf.keras.Sequential([
        layers.InputLayer(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS)),
        data_augmentation,
        layers.Lambda(preprocess_input_effnet),
        base_model_effnet,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(0.2),
        layers.Dense(1),
        layers.Activation('sigmoid')
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=base_lr),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=['binary_accuracy']
    )
    return model
model_effnet = build_model()
lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(patience=7)
early_stopping = tf.keras.callbacks.EarlyStopping(patience=13, restore_best_weights=True)
callback = [lr_reduce, early_stopping]
history_effnet = model_effnet.fit(
train_ds,
validation_data=val_ds,
epochs=80,
callbacks=callback,
verbose=False
)
model_effnet.evaluate(val_ds, return_dict=True)
14/14 [==============================] - 2s 116ms/step - loss: 0.1846 - binary_accuracy: 0.9318
{'loss': 0.18463800847530365, 'binary_accuracy': 0.9317647218704224}
effnet_hist = history_effnet.history
plot_scores(effnet_hist, 'binary_accuracy')
We obtain a good improvement over what we had with ResNet. Notice that we also train for a lot longer even though the validation accuracy quickly reaches a fairly stable level around 93%. This is because we want to attempt to further increase the accuracy by fine-tuning our model and for this it's important that the top layers have fully converged before we unfreeze any weights. If the final layers were unstable we would risk throwing the pretrained weights of the newly unfrozen layers off-balance as they would have to adjust to large backpropagated gradients.
We can see below that the EfficientNet network consists of several stages, which each contain a number of convolutional blocks (see Table 1 in the paper). We will unfreeze only the last block of the last stage and train with a much smaller learning rate of $10^{-6}$.
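A sketch of the unfreezing step we have in mind (hypothetical: the `block7b` layer-name prefix and the choice to keep BatchNormalization layers frozen are our assumptions, not code from the final model):

```python
import tensorflow as tf

# Build a bare EfficientNetB1 base (weights=None avoids a download here; the
# notebook's base carries ImageNet weights). Unfreeze only layers of the last
# block (named block7b*), keeping BatchNormalization layers frozen so their
# moving statistics are not disturbed during fine-tuning.
base = tf.keras.applications.efficientnet.EfficientNetB1(
    include_top=False, weights=None)
base.trainable = True
for layer in base.layers:
    if not layer.name.startswith('block7b') or isinstance(
            layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

# Afterwards the full model would be recompiled with a tiny learning rate,
# e.g. tf.keras.optimizers.Adam(learning_rate=1e-6), before resuming fit().
print(sum(1 for layer in base.layers if layer.trainable))
```

Recompiling is essential: changing `trainable` flags has no effect on an already-compiled model until `compile` is called again.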
model_effnet.layers[2].summary()
Model: "efficientnetb1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
rescaling (Rescaling) (None, None, None, 3 0 input_1[0][0]
__________________________________________________________________________________________________
[... EfficientNetB0 `model.summary()` excerpt, condensed: the original printout was column-clipped, truncating every output-shape channel count. The visible portion covers the normalization/stem layers and MBConv blocks 1–5. Each MBConv block repeats the same pattern: expand Conv2D → BatchNormalization → Activation, DepthwiseConv2D → BatchNormalization → Activation, a squeeze-and-excite branch (GlobalAveragePooling → reduce Conv2D → expand Conv2D → Multiply), then project Conv2D → BatchNormalization, with Dropout and a residual Add in repeated sub-blocks (b, c, d). The channel widths, recovered from the parameter counts, are 32 at the stem and 16, 24, 40, 80, 112 at the outputs of blocks 1–5 respectively. ...]
block5c_expand_bn (BatchNormali (None, None, None, 6 2688 block5c_expand_conv[0][0]
__________________________________________________________________________________________________
block5c_expand_activation (Acti (None, None, None, 6 0 block5c_expand_bn[0][0]
__________________________________________________________________________________________________
block5c_dwconv (DepthwiseConv2D (None, None, None, 6 16800 block5c_expand_activation[0][0]
__________________________________________________________________________________________________
block5c_bn (BatchNormalization) (None, None, None, 6 2688 block5c_dwconv[0][0]
__________________________________________________________________________________________________
block5c_activation (Activation) (None, None, None, 6 0 block5c_bn[0][0]
__________________________________________________________________________________________________
block5c_se_squeeze (GlobalAvera (None, 672) 0 block5c_activation[0][0]
__________________________________________________________________________________________________
block5c_se_reshape (Reshape) (None, 1, 1, 672) 0 block5c_se_squeeze[0][0]
__________________________________________________________________________________________________
block5c_se_reduce (Conv2D) (None, 1, 1, 28) 18844 block5c_se_reshape[0][0]
__________________________________________________________________________________________________
block5c_se_expand (Conv2D) (None, 1, 1, 672) 19488 block5c_se_reduce[0][0]
__________________________________________________________________________________________________
block5c_se_excite (Multiply) (None, None, None, 6 0 block5c_activation[0][0]
block5c_se_expand[0][0]
__________________________________________________________________________________________________
block5c_project_conv (Conv2D) (None, None, None, 1 75264 block5c_se_excite[0][0]
__________________________________________________________________________________________________
block5c_project_bn (BatchNormal (None, None, None, 1 448 block5c_project_conv[0][0]
__________________________________________________________________________________________________
block5c_drop (Dropout) (None, None, None, 1 0 block5c_project_bn[0][0]
__________________________________________________________________________________________________
block5c_add (Add) (None, None, None, 1 0 block5c_drop[0][0]
block5b_add[0][0]
__________________________________________________________________________________________________
block5d_expand_conv (Conv2D) (None, None, None, 6 75264 block5c_add[0][0]
__________________________________________________________________________________________________
block5d_expand_bn (BatchNormali (None, None, None, 6 2688 block5d_expand_conv[0][0]
__________________________________________________________________________________________________
block5d_expand_activation (Acti (None, None, None, 6 0 block5d_expand_bn[0][0]
__________________________________________________________________________________________________
block5d_dwconv (DepthwiseConv2D (None, None, None, 6 16800 block5d_expand_activation[0][0]
__________________________________________________________________________________________________
block5d_bn (BatchNormalization) (None, None, None, 6 2688 block5d_dwconv[0][0]
__________________________________________________________________________________________________
block5d_activation (Activation) (None, None, None, 6 0 block5d_bn[0][0]
__________________________________________________________________________________________________
block5d_se_squeeze (GlobalAvera (None, 672) 0 block5d_activation[0][0]
__________________________________________________________________________________________________
block5d_se_reshape (Reshape) (None, 1, 1, 672) 0 block5d_se_squeeze[0][0]
__________________________________________________________________________________________________
block5d_se_reduce (Conv2D) (None, 1, 1, 28) 18844 block5d_se_reshape[0][0]
__________________________________________________________________________________________________
block5d_se_expand (Conv2D) (None, 1, 1, 672) 19488 block5d_se_reduce[0][0]
__________________________________________________________________________________________________
block5d_se_excite (Multiply) (None, None, None, 6 0 block5d_activation[0][0]
block5d_se_expand[0][0]
__________________________________________________________________________________________________
block5d_project_conv (Conv2D) (None, None, None, 1 75264 block5d_se_excite[0][0]
__________________________________________________________________________________________________
block5d_project_bn (BatchNormal (None, None, None, 1 448 block5d_project_conv[0][0]
__________________________________________________________________________________________________
block5d_drop (Dropout) (None, None, None, 1 0 block5d_project_bn[0][0]
__________________________________________________________________________________________________
block5d_add (Add) (None, None, None, 1 0 block5d_drop[0][0]
block5c_add[0][0]
__________________________________________________________________________________________________
block6a_expand_conv (Conv2D) (None, None, None, 6 75264 block5d_add[0][0]
__________________________________________________________________________________________________
block6a_expand_bn (BatchNormali (None, None, None, 6 2688 block6a_expand_conv[0][0]
__________________________________________________________________________________________________
block6a_expand_activation (Acti (None, None, None, 6 0 block6a_expand_bn[0][0]
__________________________________________________________________________________________________
block6a_dwconv_pad (ZeroPadding (None, None, None, 6 0 block6a_expand_activation[0][0]
__________________________________________________________________________________________________
block6a_dwconv (DepthwiseConv2D (None, None, None, 6 16800 block6a_dwconv_pad[0][0]
__________________________________________________________________________________________________
block6a_bn (BatchNormalization) (None, None, None, 6 2688 block6a_dwconv[0][0]
__________________________________________________________________________________________________
block6a_activation (Activation) (None, None, None, 6 0 block6a_bn[0][0]
__________________________________________________________________________________________________
block6a_se_squeeze (GlobalAvera (None, 672) 0 block6a_activation[0][0]
__________________________________________________________________________________________________
block6a_se_reshape (Reshape) (None, 1, 1, 672) 0 block6a_se_squeeze[0][0]
__________________________________________________________________________________________________
block6a_se_reduce (Conv2D) (None, 1, 1, 28) 18844 block6a_se_reshape[0][0]
__________________________________________________________________________________________________
block6a_se_expand (Conv2D) (None, 1, 1, 672) 19488 block6a_se_reduce[0][0]
__________________________________________________________________________________________________
block6a_se_excite (Multiply) (None, None, None, 6 0 block6a_activation[0][0]
block6a_se_expand[0][0]
__________________________________________________________________________________________________
block6a_project_conv (Conv2D) (None, None, None, 1 129024 block6a_se_excite[0][0]
__________________________________________________________________________________________________
block6a_project_bn (BatchNormal (None, None, None, 1 768 block6a_project_conv[0][0]
__________________________________________________________________________________________________
block6b_expand_conv (Conv2D) (None, None, None, 1 221184 block6a_project_bn[0][0]
__________________________________________________________________________________________________
block6b_expand_bn (BatchNormali (None, None, None, 1 4608 block6b_expand_conv[0][0]
__________________________________________________________________________________________________
block6b_expand_activation (Acti (None, None, None, 1 0 block6b_expand_bn[0][0]
__________________________________________________________________________________________________
block6b_dwconv (DepthwiseConv2D (None, None, None, 1 28800 block6b_expand_activation[0][0]
__________________________________________________________________________________________________
block6b_bn (BatchNormalization) (None, None, None, 1 4608 block6b_dwconv[0][0]
__________________________________________________________________________________________________
block6b_activation (Activation) (None, None, None, 1 0 block6b_bn[0][0]
__________________________________________________________________________________________________
block6b_se_squeeze (GlobalAvera (None, 1152) 0 block6b_activation[0][0]
__________________________________________________________________________________________________
block6b_se_reshape (Reshape) (None, 1, 1, 1152) 0 block6b_se_squeeze[0][0]
__________________________________________________________________________________________________
block6b_se_reduce (Conv2D) (None, 1, 1, 48) 55344 block6b_se_reshape[0][0]
__________________________________________________________________________________________________
block6b_se_expand (Conv2D) (None, 1, 1, 1152) 56448 block6b_se_reduce[0][0]
__________________________________________________________________________________________________
block6b_se_excite (Multiply) (None, None, None, 1 0 block6b_activation[0][0]
block6b_se_expand[0][0]
__________________________________________________________________________________________________
block6b_project_conv (Conv2D) (None, None, None, 1 221184 block6b_se_excite[0][0]
__________________________________________________________________________________________________
block6b_project_bn (BatchNormal (None, None, None, 1 768 block6b_project_conv[0][0]
__________________________________________________________________________________________________
block6b_drop (Dropout) (None, None, None, 1 0 block6b_project_bn[0][0]
__________________________________________________________________________________________________
block6b_add (Add) (None, None, None, 1 0 block6b_drop[0][0]
block6a_project_bn[0][0]
__________________________________________________________________________________________________
block6c_expand_conv (Conv2D) (None, None, None, 1 221184 block6b_add[0][0]
__________________________________________________________________________________________________
block6c_expand_bn (BatchNormali (None, None, None, 1 4608 block6c_expand_conv[0][0]
__________________________________________________________________________________________________
block6c_expand_activation (Acti (None, None, None, 1 0 block6c_expand_bn[0][0]
__________________________________________________________________________________________________
block6c_dwconv (DepthwiseConv2D (None, None, None, 1 28800 block6c_expand_activation[0][0]
__________________________________________________________________________________________________
block6c_bn (BatchNormalization) (None, None, None, 1 4608 block6c_dwconv[0][0]
__________________________________________________________________________________________________
block6c_activation (Activation) (None, None, None, 1 0 block6c_bn[0][0]
__________________________________________________________________________________________________
block6c_se_squeeze (GlobalAvera (None, 1152) 0 block6c_activation[0][0]
__________________________________________________________________________________________________
block6c_se_reshape (Reshape) (None, 1, 1, 1152) 0 block6c_se_squeeze[0][0]
__________________________________________________________________________________________________
block6c_se_reduce (Conv2D) (None, 1, 1, 48) 55344 block6c_se_reshape[0][0]
__________________________________________________________________________________________________
block6c_se_expand (Conv2D) (None, 1, 1, 1152) 56448 block6c_se_reduce[0][0]
__________________________________________________________________________________________________
block6c_se_excite (Multiply) (None, None, None, 1 0 block6c_activation[0][0]
block6c_se_expand[0][0]
__________________________________________________________________________________________________
block6c_project_conv (Conv2D) (None, None, None, 1 221184 block6c_se_excite[0][0]
__________________________________________________________________________________________________
block6c_project_bn (BatchNormal (None, None, None, 1 768 block6c_project_conv[0][0]
__________________________________________________________________________________________________
block6c_drop (Dropout) (None, None, None, 1 0 block6c_project_bn[0][0]
__________________________________________________________________________________________________
block6c_add (Add) (None, None, None, 1 0 block6c_drop[0][0]
block6b_add[0][0]
__________________________________________________________________________________________________
block6d_expand_conv (Conv2D) (None, None, None, 1 221184 block6c_add[0][0]
__________________________________________________________________________________________________
block6d_expand_bn (BatchNormali (None, None, None, 1 4608 block6d_expand_conv[0][0]
__________________________________________________________________________________________________
block6d_expand_activation (Acti (None, None, None, 1 0 block6d_expand_bn[0][0]
__________________________________________________________________________________________________
block6d_dwconv (DepthwiseConv2D (None, None, None, 1 28800 block6d_expand_activation[0][0]
__________________________________________________________________________________________________
block6d_bn (BatchNormalization) (None, None, None, 1 4608 block6d_dwconv[0][0]
__________________________________________________________________________________________________
block6d_activation (Activation) (None, None, None, 1 0 block6d_bn[0][0]
__________________________________________________________________________________________________
block6d_se_squeeze (GlobalAvera (None, 1152) 0 block6d_activation[0][0]
__________________________________________________________________________________________________
block6d_se_reshape (Reshape) (None, 1, 1, 1152) 0 block6d_se_squeeze[0][0]
__________________________________________________________________________________________________
block6d_se_reduce (Conv2D) (None, 1, 1, 48) 55344 block6d_se_reshape[0][0]
__________________________________________________________________________________________________
block6d_se_expand (Conv2D) (None, 1, 1, 1152) 56448 block6d_se_reduce[0][0]
__________________________________________________________________________________________________
block6d_se_excite (Multiply) (None, None, None, 1 0 block6d_activation[0][0]
block6d_se_expand[0][0]
__________________________________________________________________________________________________
block6d_project_conv (Conv2D) (None, None, None, 1 221184 block6d_se_excite[0][0]
__________________________________________________________________________________________________
block6d_project_bn (BatchNormal (None, None, None, 1 768 block6d_project_conv[0][0]
__________________________________________________________________________________________________
block6d_drop (Dropout) (None, None, None, 1 0 block6d_project_bn[0][0]
__________________________________________________________________________________________________
block6d_add (Add) (None, None, None, 1 0 block6d_drop[0][0]
block6c_add[0][0]
__________________________________________________________________________________________________
block6e_expand_conv (Conv2D) (None, None, None, 1 221184 block6d_add[0][0]
__________________________________________________________________________________________________
block6e_expand_bn (BatchNormali (None, None, None, 1 4608 block6e_expand_conv[0][0]
__________________________________________________________________________________________________
block6e_expand_activation (Acti (None, None, None, 1 0 block6e_expand_bn[0][0]
__________________________________________________________________________________________________
block6e_dwconv (DepthwiseConv2D (None, None, None, 1 28800 block6e_expand_activation[0][0]
__________________________________________________________________________________________________
block6e_bn (BatchNormalization) (None, None, None, 1 4608 block6e_dwconv[0][0]
__________________________________________________________________________________________________
block6e_activation (Activation) (None, None, None, 1 0 block6e_bn[0][0]
__________________________________________________________________________________________________
block6e_se_squeeze (GlobalAvera (None, 1152) 0 block6e_activation[0][0]
__________________________________________________________________________________________________
block6e_se_reshape (Reshape) (None, 1, 1, 1152) 0 block6e_se_squeeze[0][0]
__________________________________________________________________________________________________
block6e_se_reduce (Conv2D) (None, 1, 1, 48) 55344 block6e_se_reshape[0][0]
__________________________________________________________________________________________________
block6e_se_expand (Conv2D) (None, 1, 1, 1152) 56448 block6e_se_reduce[0][0]
__________________________________________________________________________________________________
block6e_se_excite (Multiply) (None, None, None, 1 0 block6e_activation[0][0]
block6e_se_expand[0][0]
__________________________________________________________________________________________________
block6e_project_conv (Conv2D) (None, None, None, 1 221184 block6e_se_excite[0][0]
__________________________________________________________________________________________________
block6e_project_bn (BatchNormal (None, None, None, 1 768 block6e_project_conv[0][0]
__________________________________________________________________________________________________
block6e_drop (Dropout) (None, None, None, 1 0 block6e_project_bn[0][0]
__________________________________________________________________________________________________
block6e_add (Add) (None, None, None, 1 0 block6e_drop[0][0]
block6d_add[0][0]
__________________________________________________________________________________________________
block7a_expand_conv (Conv2D) (None, None, None, 1 221184 block6e_add[0][0]
__________________________________________________________________________________________________
block7a_expand_bn (BatchNormali (None, None, None, 1 4608 block7a_expand_conv[0][0]
__________________________________________________________________________________________________
block7a_expand_activation (Acti (None, None, None, 1 0 block7a_expand_bn[0][0]
__________________________________________________________________________________________________
block7a_dwconv (DepthwiseConv2D (None, None, None, 1 10368 block7a_expand_activation[0][0]
__________________________________________________________________________________________________
block7a_bn (BatchNormalization) (None, None, None, 1 4608 block7a_dwconv[0][0]
__________________________________________________________________________________________________
block7a_activation (Activation) (None, None, None, 1 0 block7a_bn[0][0]
__________________________________________________________________________________________________
block7a_se_squeeze (GlobalAvera (None, 1152) 0 block7a_activation[0][0]
__________________________________________________________________________________________________
block7a_se_reshape (Reshape) (None, 1, 1, 1152) 0 block7a_se_squeeze[0][0]
__________________________________________________________________________________________________
block7a_se_reduce (Conv2D) (None, 1, 1, 48) 55344 block7a_se_reshape[0][0]
__________________________________________________________________________________________________
block7a_se_expand (Conv2D) (None, 1, 1, 1152) 56448 block7a_se_reduce[0][0]
__________________________________________________________________________________________________
block7a_se_excite (Multiply) (None, None, None, 1 0 block7a_activation[0][0]
block7a_se_expand[0][0]
__________________________________________________________________________________________________
block7a_project_conv (Conv2D) (None, None, None, 3 368640 block7a_se_excite[0][0]
__________________________________________________________________________________________________
block7a_project_bn (BatchNormal (None, None, None, 3 1280 block7a_project_conv[0][0]
__________________________________________________________________________________________________
block7b_expand_conv (Conv2D) (None, None, None, 1 614400 block7a_project_bn[0][0]
__________________________________________________________________________________________________
block7b_expand_bn (BatchNormali (None, None, None, 1 7680 block7b_expand_conv[0][0]
__________________________________________________________________________________________________
block7b_expand_activation (Acti (None, None, None, 1 0 block7b_expand_bn[0][0]
__________________________________________________________________________________________________
block7b_dwconv (DepthwiseConv2D (None, None, None, 1 17280 block7b_expand_activation[0][0]
__________________________________________________________________________________________________
block7b_bn (BatchNormalization) (None, None, None, 1 7680 block7b_dwconv[0][0]
__________________________________________________________________________________________________
block7b_activation (Activation) (None, None, None, 1 0 block7b_bn[0][0]
__________________________________________________________________________________________________
block7b_se_squeeze (GlobalAvera (None, 1920) 0 block7b_activation[0][0]
__________________________________________________________________________________________________
block7b_se_reshape (Reshape) (None, 1, 1, 1920) 0 block7b_se_squeeze[0][0]
__________________________________________________________________________________________________
block7b_se_reduce (Conv2D) (None, 1, 1, 80) 153680 block7b_se_reshape[0][0]
__________________________________________________________________________________________________
block7b_se_expand (Conv2D) (None, 1, 1, 1920) 155520 block7b_se_reduce[0][0]
__________________________________________________________________________________________________
block7b_se_excite (Multiply) (None, None, None, 1 0 block7b_activation[0][0]
block7b_se_expand[0][0]
__________________________________________________________________________________________________
block7b_project_conv (Conv2D) (None, None, None, 3 614400 block7b_se_excite[0][0]
__________________________________________________________________________________________________
block7b_project_bn (BatchNormal (None, None, None, 3 1280 block7b_project_conv[0][0]
__________________________________________________________________________________________________
block7b_drop (Dropout) (None, None, None, 3 0 block7b_project_bn[0][0]
__________________________________________________________________________________________________
block7b_add (Add) (None, None, None, 3 0 block7b_drop[0][0]
block7a_project_bn[0][0]
__________________________________________________________________________________________________
top_conv (Conv2D) (None, None, None, 1 409600 block7b_add[0][0]
__________________________________________________________________________________________________
top_bn (BatchNormalization) (None, None, None, 1 5120 top_conv[0][0]
__________________________________________________________________________________________________
top_activation (Activation) (None, None, None, 1 0 top_bn[0][0]
==================================================================================================
Total params: 6,575,239
Trainable params: 0
Non-trainable params: 6,575,239
__________________________________________________________________________________________________
print(str(next(x for x, val in enumerate(model_effnet.layers[2].layers) if val.name == 'block7b_expand_conv')))
321
model_effnet.layers[2].trainable = True
fine_tune_from = 321
# Freeze everything below the fine-tuning cutoff.
for layer in model_effnet.layers[2].layers[:fine_tune_from]:
    layer.trainable = False
# Keep the batch normalisation layers frozen even above the cutoff.
for layer in model_effnet.layers[2].layers[fine_tune_from:]:
    if layer.name[-3:] == "_bn":
        layer.trainable = False
        print(f"Layer {layer.name} won't be trained.")
model_effnet.compile(
    optimizer=tf.keras.optimizers.Adam(lr=1e-5),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['binary_accuracy'],
)
initial_epochs = len(history_effnet.epoch)
finetune_epochs = 40
total_epochs = initial_epochs + finetune_epochs
history_fine = model_effnet.fit(
    train_ds,
    validation_data=val_ds,
    epochs=total_epochs,
    initial_epoch=history_effnet.epoch[-1],
    verbose=False,
)
Layer block7b_expand_bn won't be trained.
Layer block7b_bn won't be trained.
Layer block7b_project_bn won't be trained.
Layer top_bn won't be trained.
model_effnet.evaluate(val_ds, return_dict=True)
14/14 [==============================] - 2s 119ms/step - loss: 0.1572 - binary_accuracy: 0.9365
{'loss': 0.15715086460113525, 'binary_accuracy': 0.9364705681800842}
effnet_hist_fine = history_fine.history
plot_scores({key: val + effnet_hist_fine[key] for (key, val) in effnet_hist.items() if key in ['loss', 'val_loss', 'binary_accuracy', 'val_binary_accuracy']}, 'binary_accuracy', finetune=initial_epochs)
We don't see much gain in terms of accuracy, but the loss improves by roughly 15% relative to before fine-tuning. However, fine-tuning much further would put us at risk of overfitting. Hence, we'll call it a day and move on to investigating what our network has actually learnt. First, though, let's quickly check the predictions for our test set again.
predictions = (model_effnet.predict(test_ds) > 0.5).ravel().astype(int)
plt.figure(figsize=(10, 10))
for image, label in test_ds.take(1):
    for i in range(len(predictions)):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(image[i].numpy().astype("uint8"))
        plt.title(f"{class_names[predictions[i]]}")
        plt.axis(False)
As a brief aside, there is currently a bug in Keras' implementation of EfficientNet that affects model saving when custom gradients are used. It is still possible to load the saved model and do inference (with a bunch of warnings thrown at you), but any further training has a possibility of failing. This issue is tracked at
https://github.com/tensorflow/tensorflow/issues/40166#issuecomment-756702752, where a workaround is also provided. Since we are fine-tuning our model, we have to modify the workaround slightly: we save the weights and the optimizer state separately, rebuild the model, and then restore both.
model_effnet.save_weights('model_effnet_weights.h5')
symbolic_weights = getattr(model_effnet.optimizer, 'weights')
weight_values = tf.keras.backend.batch_get_value(symbolic_weights)
with open('model_effnet_optimizer.pkl', 'wb') as f:
    pickle.dump(weight_values, f)
model_effnet = build_model()
model_effnet.layers[2].trainable = True
fine_tune_from = 321
for layer in model_effnet.layers[2].layers[:fine_tune_from]:
    layer.trainable = False
for layer in model_effnet.layers[2].layers[fine_tune_from:]:
    if layer.name[-3:] == "_bn":
        layer.trainable = False
model_effnet.compile(
    optimizer=tf.keras.optimizers.Adam(lr=1e-5),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['binary_accuracy'],
)
model_effnet.load_weights('model_effnet_weights.h5')
# Apply a dummy all-zero gradient step so that the optimizer creates
# its slot variables before we restore their saved values.
grad_vars = model_effnet.trainable_weights
zero_grads = [tf.zeros_like(w) for w in grad_vars]
model_effnet.optimizer.apply_gradients(zip(zero_grads, grad_vars))
with open('model_effnet_optimizer.pkl', 'rb') as f:
    weight_values = pickle.load(f)
model_effnet.optimizer.set_weights(weight_values)
Our final task is to investigate our model and determine whether it has actually learnt what we would expect. There are multiple ways to do this, including saliency maps, class activation maps, class model visualisation and inspection of specific filters based on image patches with the highest activation, to name just a few. For now, we'll focus on saliency maps and class model visualisation and leave the rest for future work. We'll also show how to use saliency maps to perform crude object localisation. Finally, we verify our observations with the help of two existing methods for model explainability: SHAP values and LIME models.
Class saliency maps were introduced by Simonyan et al. (2014). The basic idea is simple: we compute the derivative of the output of our network with respect to the input image. This tells us which pixels in the original image would have the largest effect on the output (i.e. the class posterior) if we were to alter them by a small amount. Since our pictures have 3 colour channels, the gradient has three components per pixel; to obtain a single-channel map we take the maximum of their absolute values.
Here is a simple implementation of the saliency map computation with the help of tf.GradientTape.
from mpl_toolkits.axes_grid1 import make_axes_locatable
input_img = PIL.Image.open("test/dachshund/albert.jpg").resize((IMG_HEIGHT, IMG_WIDTH))
def saliency(img):
    x_sal = tf.convert_to_tensor(np.array(img), dtype=tf.float32)
    x_sal = tf.expand_dims(x_sal, axis=0)
    with tf.GradientTape() as tape:
        tape.watch(x_sal)
        output = model_effnet(x_sal, training=False)
    gradients = tape.gradient(output, x_sal)
    g = np.absolute(gradients[0].numpy())
    g = np.max(g, axis=2)
    return g
g = saliency(input_img)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,5), constrained_layout=True)
ax1.imshow(input_img)
ax1.axis(False)
im = ax2.imshow(g, cmap='inferno')
ax2.axis(False)
fig.colorbar(im, ax=ax2, shrink=0.72)
plt.show()
In the original paper the authors mention how saliency maps can be used as part of a more sophisticated localisation algorithm (which still falls far short of modern methods). The idea is that by thresholding on different quantiles of the saliency map distribution (the 30% and 95% quantiles, to be exact), we can separate the image into foreground and background pixels. Then, by fitting a Gaussian mixture model to each part, we obtain a probabilistic representation of both parts of the picture, to which existing computer vision methods can be applied with good results. We won't go further into that here, but will simply show a quick example of how to fit the Gaussian models:
from sklearn.mixture import GaussianMixture
g = saliency(input_img)
grad = (g - g.min()) / np.ptp(g)
g_95 = np.where(grad > np.quantile(grad, 0.95), grad, 0)
g_30 = np.where(grad < np.quantile(grad, 0.3), grad, 0)
fg_model = GaussianMixture(n_components=3, random_state=SEED)
fg_model.fit(g_95.reshape(-1,1))
bg_model = GaussianMixture(n_components=3, random_state=SEED)
bg_model.fit(g_30.reshape(-1,1))
fg_threshold = np.mean(fg_model.means_)
fg = grad > fg_threshold
bg_threshold = np.mean(bg_model.means_)
bg = grad < bg_threshold
plt.subplot(121)
plt.imshow(fg, cmap='Blues')
plt.title('Foreground')
plt.axis(False)
plt.subplot(122)
plt.imshow(bg, cmap='Reds')
plt.title('Background')
plt.axis(False)
plt.show()
What we want to do instead is a lot simpler. We'll treat the saliency map as a mass distribution over the original image and use it to find a bounding box, which captures an optimal portion of the total mass relative to its area. More precisely, we first compute the centre of mass of the full mass distribution, located at, say, (c_x, c_y), and start fitting bounding boxes relative to it according to the following algorithm:
1. Compute the centre of mass (c_x, c_y) of a given saliency map.
2. Initialise a rectangle R with a relative centre (c_x, c_y) which fits the full image. By a relative centre we mean that if we adjust the width of any pair of sides of R, then (c_x, c_y) has to stay within the adjusted rectangle.
3. Choose a side of R and construct R' by moving the chosen side of R perpendicularly towards its centroid so that the relative mass inside R' is tol of the total mass in R (a good default value for our data seems to be around tol=0.90).
4. Set R := R' and continue from step 3. with the next chosen side.
5. Once every side has been adjusted, return R as R_final.

Notice that the remaining total mass is adjusted in step 4. after the assignment. With this method we obtain quite stable bounding boxes from our saliency maps and they seem to be somewhat accurate as long as the classified object doesn't fully occupy the original image. To ensure stability, we also perform some additional smoothing in two ways: first, we run the algorithm for each possible ordering of sides (i.e. we get 4! = 24 rectangles) and use their mean as R_final, and second, in step 3. we choose the final side length only after smoothing the curve for the change of mass (when moving that side).
We've created a simple class Rect (see rel_rect.py), which allows us to work with rectangles with a relative centre. The constructor of Rect takes 6 arguments c_x, c_y, r_w, b_h, l_w, t_h, i.e. the coordinates of the relative centre, right width of the rectangle, bottom height, left width and top height, respectively. The class then coerces the rectangle so that it is guaranteed to fit inside our initial image (of size 224x224) and exposes attributes Rect.x_left, Rect.x_right, Rect.y_top and Rect.y_bot for the coordinates of each of the sides. We can then adjust these sides separately with the methods Rect.set_rw(r_w), Rect.set_bh(b_h) etc. After each adjustment the class automatically computes new values for all the properties of the rectangle.
from rel_rect import Rect
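For reference, here is a minimal sketch of what such a class might look like, based only on the interface described above (the actual rel_rect.py may differ in its details):

```python
# Hypothetical sketch of rel_rect.Rect, inferred from the described interface.
IMG_SIZE = 224

class Rect:
    def __init__(self, c_x, c_y, r_w, b_h, l_w, t_h):
        self.c_x, self.c_y = int(c_x), int(c_y)
        self.r_w, self.b_h, self.l_w, self.t_h = r_w, b_h, l_w, t_h
        self._coerce()

    def _coerce(self):
        # clamp each half-extent so the rectangle stays inside the image
        # while keeping the relative centre (c_x, c_y) inside the rectangle
        self.r_w = max(0, min(self.r_w, IMG_SIZE - 1 - self.c_x))
        self.l_w = max(0, min(self.l_w, self.c_x))
        self.b_h = max(0, min(self.b_h, IMG_SIZE - 1 - self.c_y))
        self.t_h = max(0, min(self.t_h, self.c_y))
        # derived side coordinates and total width/height
        self.x_left = int(self.c_x - self.l_w)
        self.x_right = int(self.c_x + self.r_w)
        self.y_top = int(self.c_y - self.t_h)
        self.y_bot = int(self.c_y + self.b_h)
        self.w = self.x_right - self.x_left
        self.h = self.y_bot - self.y_top

    def set_rw(self, r_w):
        self.r_w = r_w
        self._coerce()

    def set_lw(self, l_w):
        self.l_w = l_w
        self._coerce()

    def set_th(self, t_h):
        self.t_h = t_h
        self._coerce()

    def set_bh(self, b_h):
        self.b_h = b_h
        self._coerce()
```

The key invariant is that every setter re-coerces the rectangle, so the side coordinates and dimensions are always consistent after an adjustment.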
We need a simple function to compute the mass of the saliency map. This is nothing but the sum of the weights (which have been normalised to lie between 0 and 1) within the given rectangle. The function get_mass below takes a positional argument x, which is the full saliency map, and an optional keyword argument r, which is the rectangle whose mass we want to compute. If no rectangle is provided, we return the total mass of the saliency map.
def get_mass(x, r=None):
    if r is None:
        return np.sum(x)
    else:
        return np.sum(x.T[r.x_left : r.x_right + 1, r.y_top : r.y_bot + 1])

def get_density(x, r):
    return np.mean(x.T[r.x_left : r.x_right, r.y_top : r.y_bot])
We can then implement the algorithm we described as follows. It consists of three functions: adjust_side, adjust_rec and fit_recs. adjust_side does the brunt of the work. For a given rectangle it adjusts the location of the given side so that the relative mass is approximately equal to tol (after smoothing). Note that this adjustment is done in place. adjust_rec then simply calls adjust_side repeatedly for a given ordering of the sides for the inputted rectangle and returns a new rectangle, which is R_final. Finally, fit_recs is responsible for constructing the saliency map for a given input image, calling adjust_rec on it for all permutations of side orderings and for smoothing the final result.
from itertools import permutations
def adjust_side(x, rec, side, tol=0.9, smoothing=5):
    # modifies rec in place
    if side not in ['r_w', 'l_w', 't_h', 'b_h']:
        raise ValueError("Invalid length name to Rectangle.")
    if rec is None:
        raise ValueError("Got empty Rectangle.")
    remaining_mass = get_mass(x, rec)
    lengths = np.arange(1, rec.__dict__[side] + 1)
    mass_prop = []
    # if we already have a minimal rect on this side
    if len(lengths) == 0:
        print("Got empty rec")
        return False
    func_name = {'r_w': 'set_rw', 'l_w': 'set_lw', 't_h': 'set_th', 'b_h': 'set_bh'}[side]
    adj = getattr(rec, func_name)
    for l in lengths:
        adj(l)
        mass_prop.append(get_mass(x, rec) / remaining_mass)
    # optional smoothing with a running average:
    if smoothing > 0 and len(mass_prop) > 0:
        smooth = min(smoothing, len(mass_prop))
        mass_prop = np.convolve(mass_prop, np.ones(smooth)/smooth, mode='full')[:len(mass_prop)]
    # find the first length at which we retain tol of the initial mass
    cutoff = np.argmax(np.array(mass_prop) > tol)
    adj(cutoff)
    return True
def adjust_rec(x, rec, sides=['r_w', 'l_w', 't_h', 'b_h'], tol=0.9, smooth=5):
    if rec is None:
        raise ValueError("Got empty Rectangle.")
    if not (set(sides) == set(['r_w', 'l_w', 't_h', 'b_h']) and len(sides) == 4):
        raise ValueError("Invalid length name to Rectangle.")
    # we modify a copy
    rec = Rect(rec.c_x, rec.c_y, rec.r_w, rec.b_h, rec.l_w, rec.t_h)
    for side in sides:
        adjust_side(x, rec, side, tol, smooth)
    return rec
def fit_recs(sal_img, tol=0.9, smooth=5):
    if sal_img is None:
        raise ValueError("Got empty sal_img.")
    # normalise the map as per our assumptions
    x = (sal_img - sal_img.min()) / np.ptp(sal_img)
    # locate the centre at the peaks of the column/row mass profiles
    c_x = np.argmax(np.mean(x, axis=0))
    c_y = np.argmax(np.mean(x, axis=1))
    side_order = ['t_h', 'b_h', 'l_w', 'r_w']
    perms = permutations(side_order)
    # initialise the rec to max possible size irrespective of centroid loc
    full_rec = Rect(c_x, c_y, 223 - c_x, 223 - c_y, c_x, c_y)
    recs = []
    for p in perms:
        recs.append(adjust_rec(x, full_rec, p, tol=tol, smooth=smooth))
    # compute the final result as the mean of all the rectangles in 'recs'
    r_w = np.mean([r.r_w for r in recs])
    l_w = np.mean([r.l_w for r in recs])
    t_h = np.mean([r.t_h for r in recs])
    b_h = np.mean([r.b_h for r in recs])
    final_rec = Rect(c_x, c_y, r_w, b_h, l_w, t_h)
    return recs, final_rec
We can see the finished result below. The centroid is plotted as a red dot, each of the 24 R_final rectangles is drawn in red and finally their mean is shown in cyan. We'll see later that quite often the centroid is not located in what we would typically think of as the centre of the object (dog, in this case). However, based on our experiments, this does not in fact seem to be an issue and our final rectangles are fairly stable even if the centroid is very close to the edges. Moreover, our original reasoning for using this form of construction was to ensure a certain reasonable constraint for each rectangle, which should help with stability. We'll soon see that this is indeed what seems to have happened.
from matplotlib.patches import Rectangle
def plot_recs(ax, recs, final_rec):
    ax.scatter(final_rec.c_x, final_rec.c_y, c='r')
    for r in recs:
        ax.add_patch(Rectangle((r.x_left, r.y_top), r.w, r.h, color='red', fill=False))
    ax.add_patch(Rectangle((final_rec.x_left, final_rec.y_top), final_rec.w, final_rec.h,
                           color='cyan', fill=False, linewidth=2))
sal = saliency(input_img)
recs, final_rec = fit_recs(sal, 0.9, smooth=2)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,5), constrained_layout=True)
ax1.imshow(input_img)
ax2.imshow(sal)
for ax in [ax1, ax2]:
    plot_recs(ax, recs, final_rec)
    ax.axis(False)
plt.show()
Here we see the algorithm applied to each image from our test set. It seems to provide reasonably good results except possibly when the target object occupies the whole image.
n_images = 8
tol = 0.9
n_rows = n_images // 2
n_cols = min(2 * n_images, 4)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(9, n_rows * 2.25), constrained_layout=True)
for img, _ in test_ds.take(1):
    for i in range(n_rows):
        for j in range(n_cols // 2):
            axs[i, 2*j].imshow(img[2*i+j].numpy().astype('uint8'))
            axs[i, 2*j].axis(False)
            sal = saliency(img[2*i+j].numpy())
            axs[i, 2*j+1].imshow(sal)
            axs[i, 2*j+1].axis(False)
            recs, final_rec = fit_recs(sal, tol, smooth=3)
            plot_recs(axs[i, 2*j], recs, final_rec)
            plot_recs(axs[i, 2*j+1], recs, final_rec)
plt.show()
There are multiple ways to improve this algorithm. One idea would be to use np.gradient to incorporate the rate of change of mass as we adjust the sides, which would allow us to pick the threshold in a more controlled manner. We could also try to use the Gaussian mixture models for the foreground and background pixels from the last section and fit the bounding boxes based on those instead. It would also be interesting to investigate how our method would perform on a more complicated network trained on a bigger dataset (e.g. multiclass models on the ImageNet dataset).
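The first idea could be sketched roughly as follows (cutoff_from_gradient is a hypothetical helper, not part of the notebook's code, and the flatness threshold would need tuning):

```python
import numpy as np

def cutoff_from_gradient(mass_prop, min_slope=1e-3):
    # Hypothetical alternative to the fixed-tol cutoff in adjust_side:
    # pick the first length at which the mass curve flattens out, i.e.
    # where growing the side further captures almost no additional mass.
    d = np.gradient(np.asarray(mass_prop, dtype=float))
    flat = np.nonzero(d < min_slope)[0]
    return int(flat[0]) if flat.size else len(mass_prop) - 1
```

This replaces the absolute mass-proportion threshold with a slope criterion, so the cutoff adapts to how quickly the saliency mass is actually accumulating for a given image.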
In the same paper the authors also discuss class model visualisation. The basic idea is to again take the gradient of the output with respect to the input and then modify (with gradient ascent) the input image so as to maximise the final class score. Ideally, this procedure should show us what the network perceives as a picture with maximal probability of being a dachshund. More precisely, our learning objective is the following. Denote by $\mathscr{I}$ the input image and let $S_c(\mathscr{I})$ be its linear class score (i.e. the output before sigmoid activation, so $S_c > 0$ corresponds to $c=\text{dachshund}$). We then want to find $\mathscr{I}$, which is a solution of the optimisation problem
$$\mathop{\mathrm{argmax}}_{\mathscr{I}}S_{c}(\mathscr{I}) - \lVert\mathscr{I}\rVert_{2}^{2},$$where the additional $\ell_{2}$-regularisation ensures that the pixel values don't blow up and the final image stays somewhat smooth (by strongly penalising outliers). We do this by initialising $\mathscr{I}=\mathbf{0}$ and performing gradient ascent. We then add the mean $\mathscr{I}_{0}$ of the training set to the solution $\mathscr{I}$ to obtain our class visualisation.
Notice that it is important to use the class score and not the final activation. To see this, consider a multiclass classifier with softmax activation. The final output of such a network for a fixed class $c$ is $e^{S_c}/\sum_{c'}e^{S_{c'}}$. As pointed out in the paper, if we were to maximise this, we could do it simply by minimising the class score for all classes $c'\neq c$, which would not imply that the final image is representative of the features of class $c$. There seems to be an additional reason, which was not mentioned in the paper. In our case we only have 2 classes so the above is not a problem a priori. However, both softmax and sigmoid functions have the same issue that as the activation saturates the derivatives vanish (which is one of the reasons why most deep NNs have pivoted to ReLU activation in hidden layers). But this is precisely what we're optimising for and so if we were to use the class posteriors then it makes sense that the training would slow down considerably. Therefore, we remove the sigmoid activation and consider the linear output of our model instead.
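The saturation argument is easy to check numerically: the derivative of the sigmoid is $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, which collapses once the class score is large.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_grad(x):
    # derivative of the sigmoid: sigma(x) * (1 - sigma(x))
    s = sigmoid(x)
    return s * (1.0 - s)

# The gradient shrinks by orders of magnitude as the class score grows,
# from 0.25 at S_c = 0 down to roughly 4.5e-5 at S_c = 10:
for x in [0.0, 5.0, 10.0]:
    print(f"S_c = {x:4.1f}  sigmoid'(S_c) = {sigmoid_grad(x):.2e}")
```

With gradients this small, gradient ascent through the sigmoid would barely move the image once the model is confident, which is exactly why we optimise the linear score instead.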
model_effnet_linear = tf.keras.Model(inputs=model_effnet.input, outputs=model_effnet.layers[-2].output)
# Initialise I_0 to be the mean of the training set
I_0 = np.zeros((IMG_HEIGHT, IMG_WIDTH, CHANNELS))
I_0 = tf.convert_to_tensor(I_0, dtype=tf.float32)
n = 0
for img, label in train_ds.unbatch():
    n += 1
    I_0 += img
I_0 /= n
plt.imshow(I_0.numpy().squeeze().astype('uint8'))
plt.axis(False)
plt.show()
We run the Adam optimiser for 3000 iterations with learning rate 0.3 and epsilon $10^{-11}$. Based on our experiments, the choice of optimisation algorithm has a big effect on the look of the final image. Moreover, we obtain visually better results with the flag training=True, since this keeps any BatchNormalization layers active.
%%time
lbd = 0.01
lr = 0.3
I = tf.Variable(tf.zeros_like(I_0), name='I', dtype=tf.float32)
loss = lambda: -(model_effnet_linear(tf.expand_dims(I, axis=0), training=True)
                 - lbd * tf.math.reduce_sum(tf.math.square(I / 255.)))
opt = tf.keras.optimizers.Adam(learning_rate=lr, epsilon=1e-11)
steps = 3000
draw = False
save = True
save_freq = 10
save_folder = 'anim'
idx = 0
if not os.path.exists(save_folder):
    os.makedirs(save_folder)
for i in range(steps):
    if save and i % save_freq == 0:
        img = PIL.Image.fromarray(tf.squeeze(I).numpy().astype('uint8'))
        img.save(os.path.join(save_folder, f'{idx:04}.png'), 'PNG')
        idx += 1
    if i % 10 == 9:
        if draw:
            clear_output()
            plt.imshow(tf.squeeze(I).numpy().astype('uint8'))
            plt.axis(False)
            plt.show()
        print(f"{i+1}/{steps}", end="\r", flush=True)
    opt.minimize(loss, var_list=[I])
    tf.debugging.assert_all_finite(I, str(i))
Wall time: 14min 34s
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,4))
ax1.imshow(tf.squeeze(I+I_0).numpy().astype('uint8'))
ax1.axis(False)
ax2.imshow(tf.squeeze(I).numpy().astype('uint8'))
ax2.axis(False)
fig.suptitle('Class model for $c$=dachshund', fontsize=16)
plt.show()
On the right is the pure optimised image $\mathscr{I}$ (normalised) and on the left we have the final class model $\mathscr{I}+\mathscr{I}_{0}$. The results are not as pleasing as we expected (cf. the representations obtained in the original paper), but it is certainly possible to see that the model focuses a lot on the shape of the ears, which is highly distinctive of dachshunds. Below we show some other class model visualisations (which might present higher-level features) for the dachshund class that were obtained with different training parameters:

The results for the class "other", which contains a huge variety of breeds, are less pleasing:

Video of backpropagation for dachshund class model:
model_effnet_linear.predict(tf.expand_dims(I, axis=0)).item()
38.9136962890625
Our complicated final image $\mathscr{I}$ obtains an extremely high linear score of ~38.9, which, after sigmoid activation, is equal to 1.0 to 16 decimal places. Notice that in the original paper the features of each particular class are more prominent in their class model visualisations. It is unclear whether this was generally true or if these examples were cherry-picked out of the 1000 classes. However, we present a few possible explanations as to why our results seem a bit different. First of all, as we pointed out above, the optimisation problem is extremely tricky. This is because we don't want to converge to just any local optimum, but one which provides us with some visible high-level features that humans can identify when looking at the picture. It is therefore difficult to find the right combination of hyperparameters to achieve this. Second, since we trained a model for a binary classification task between two very similar types of objects, it is plausible that our model focuses heavily on smaller details as opposed to what typical multiclass classifiers might do. It would be possible to determine whether this is indeed the case by examining which image patches activate which filters of the higher convolutional layers in our model. We'll leave that for future work. Another interesting avenue for investigation would be to see how the class models are affected if we keep other training parameters constant but initialise our image $\mathscr{I}$ with some mild (clamped) noise.
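As a quick sanity check of the "1.0 to 16 decimal places" claim: at a linear score of ~38.9 the term $e^{-S_c}$ is already below double-precision machine epsilon, so the sigmoid evaluates to exactly 1.0 in float64.

```python
import numpy as np

s = 38.9136962890625                # the linear score from above
print(np.exp(-s))                   # ~1.3e-17, below float64 eps (~2.2e-16)
print(1.0 / (1.0 + np.exp(-s)))     # evaluates to exactly 1.0 in float64
```

In other words, the class posterior is indistinguishable from 1 at machine precision, which also underlines the vanishing-gradient argument made earlier.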
In the following cell we have the code we used to create the frames for the animation.
data_folder = 'anim/*png'
out_folder = 'anim_mu'
I_mu = tf.squeeze(I_0).numpy().astype('uint8')
if not os.path.exists(out_folder):
    os.makedirs(out_folder)
imgs = [PIL.Image.open(f) for f in sorted(glob.glob(data_folder))]
imgs_mu = [PIL.Image.fromarray(np.array(i) + I_mu) for i in imgs]
for idx, img in enumerate(imgs_mu):
    comb_img = PIL.Image.new('RGB', (img.width * 2, img.height))
    comb_img.paste(img, (0, 0))
    comb_img.paste(imgs[idx], (img.width, 0))
    comb_img.save(os.path.join(out_folder, f'{idx:04}.png'), 'PNG')
In this section we use existing implementations of two popular methods for explaining machine learning models, namely SHAP and Lime. Both of these methods are model agnostic, which means they will work with any machine learning model. Moreover, the implementations we use can explain many data types, such as tabular, textual or image data, with little effort. The crux of this class of "explainers" is to approximate the complicated deep learning model (in this case) by another model (called the explanation model) that is easier to understand, such as a linear model or a decision tree.
The first method we'll look at is called SHAP and is based on the paper by Lundberg and Lee (2017). The basic idea is to extend the notion of Shapley values to the setting of complex machine learning models. For each individual prediction, these values assign every feature a weight describing how much (either positively or negatively) including that feature improves the model's prediction over one lacking it; the Shapley value of a feature is a weighted average of these contributions over all subsets of the remaining features. In general, computing these values exactly is extremely computationally intensive (not only because it would require refitting a model at each step): we need to figure out how much adding a specific feature to any subset of input features changes the prediction, so one has to consider the power set of the features and the permutations of the elements in each. The idea proposed in the paper is to approximate the simplified model (with some missing features) by taking expectations and considering $\mathbb{E}(f(z)|z_{S})$, where $f$ is the original model and $z_{S}$ denotes an input with features not in $S$ set to zero. This quantity can then be approximated by assuming that the features are independent and that the model is linear. In the case of image data the raw features are of course the original pixels, but since this number is usually intractably large, many explanation models simplify the situation by considering super pixels (neighbourhoods of pixels with similar colour values) or larger patches of pixels at once.
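To make the exponential cost concrete, here is a toy exact Shapley computation over a tiny "model" defined directly on feature coalitions (an illustration of the definition, not the shap library's implementation):

```python
from itertools import combinations
from math import factorial

def shapley_values(f, n_features):
    # Exact Shapley values for a model f defined on feature subsets. The
    # loop over all subsets of the remaining features is what makes the
    # exact computation scale exponentially with the number of features.
    phi = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for r in range(len(others) + 1):
            for S in combinations(others, r):
                w = (factorial(len(S)) * factorial(n_features - len(S) - 1)
                     / factorial(n_features))
                phi[i] += w * (f(set(S) | {i}) - f(set(S)))
    return phi

# Toy "model": the value f assigns to each coalition of two features
v = {frozenset(): 0, frozenset({0}): 10, frozenset({1}): 20, frozenset({0, 1}): 50}
f = lambda S: v[frozenset(S)]
print(shapley_values(f, 2))  # -> [20.0, 30.0]
```

Note that the two values sum to f(full) - f(empty) = 50, the efficiency property that makes Shapley values attractive as feature attributions.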
For neural networks, SHAP comes with two types of explanation models that accurately approximate the true SHAP values: shap.DeepExplainer and shap.GradientExplainer. Unfortunately, both of these are currently broken for TF 2 models containing specific types of layers (e.g. global pooling), so we cannot use them. Instead we have to rely on the generic shap.Explainer, which works for any model but returns less accurate approximations and is less efficient.
input_img = PIL.Image.open("test/dachshund/albert.jpg").resize((IMG_HEIGHT, IMG_WIDTH))
x = tf.convert_to_tensor(np.array(input_img), dtype=tf.float32)
x = tf.expand_dims(x, axis=0)
def f(X):
    tmp = X.copy()
    return model_effnet(tmp)
masker = shap.maskers.Image("inpaint_telea", x[0].shape)
e = shap.Explainer(f, masker)
shap_values = e(np.array(x), max_evals=2000, batch_size=128, silent=True)
shap.image_plot(shap_values)
In the above explanation the colour and intensity of each rectangle shows how much it contributed to the final class prediction (i.e. 1 for a dachshund), with red signifying a positive contribution and blue a negative one. Naturally the above picture shows nothing surprising given our experiments with saliency maps, save for the fact that the model seems to put great emphasis on the shape of the head and nose. If we were able to use one of the NN-specific explainers we might be able to obtain more insight with higher-fidelity explanations. Later we'll remedy this issue with Lime, but first let's look at a few more examples.
x_test = list(test_ds.as_numpy_iterator())[0][0]
shap_values = e(x_test[:3], max_evals=2000, batch_size=128, silent=True)
shap.image_plot(shap_values, width=60)
Notice that in the case of a negative prediction (i.e. 0 for "other") it is the blue squares that encode the features contributing to the model's decision to not classify the dog as a dachshund since we are dealing with binary classification.
In order to get more precise explanations, we'll next look at Lime, based on the work of Ribeiro, Singh and Guestrin (2016). This predates SHAP somewhat and can be seen as a local approximation to the true SHAP values (see the discussion in the SHAP paper). In practice, Lime looks at the super pixels weighted by their proximity to the region of the picture it is trying to explain. It then fits a linear model, according to these weights, which locally approximates the underlying black-box model. One typical issue arising from this approach is that the definitions of local and neighbourhood greatly depend on the task, or even the particular model at hand, and it is generally difficult to define good default values. Instead, one has to have some baseline understanding of what constitutes a reasonable explanation for the model in order to tune these parameters via trial and error. See here for further discussion. In any case, as we see below, this is rather simple for image-based models and we gain some further insight that was not evident with SHAP.
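The core loop of Lime can be sketched in a few lines (lime_weights is a hypothetical helper written for illustration, not the lime library's API; here "superpixels" are just abstract on/off features and the kernel width is a made-up default):

```python
import numpy as np

def lime_weights(predict_masked, n_superpixels, n_samples=500,
                 kernel_width=0.25, rng=None):
    # Sketch of Lime's core idea: randomly switch superpixels on/off,
    # weight each perturbed sample by its proximity to the original image,
    # then fit a weighted linear model whose coefficients are the
    # per-superpixel explanations.
    rng = np.random.default_rng(rng)
    Z = rng.integers(0, 2, size=(n_samples, n_superpixels))   # binary masks
    y = np.array([predict_masked(z) for z in Z], dtype=float) # black-box preds
    d = 1.0 - Z.mean(axis=1)                   # distance: fraction switched off
    w = np.exp(-(d ** 2) / kernel_width ** 2)  # exponential proximity kernel
    X = np.hstack([np.ones((n_samples, 1)), Z])  # intercept + mask features
    # weighted least squares via sqrt-weight scaling
    sw = np.sqrt(w)
    beta, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return beta[1:]  # per-superpixel importance weights
```

For an actual image, predict_masked would render the mask (replacing switched-off superpixels with a baseline colour) before calling the classifier; the returned weights are what lime_image visualises as green and red regions.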
def explain_lime(img, num_samples=500, num_features=10):
    e = lime_image.LimeImageExplainer()
    ex = e.explain_instance(img, model_effnet.predict, top_labels=1,
                            hide_color=0, num_samples=num_samples)
    temp, mask = ex.get_image_and_mask(ex.top_labels[0], positive_only=False,
                                       num_features=num_features, hide_rest=False)
    plt.imshow(mark_boundaries((temp/255) / 2 + 0.5, mask))
explain_lime(x_test[2].astype('double'), num_samples=2000, num_features=15)
As we can see, the information provided by Lime is much more precise than that from SHAP (mainly because it relies on super pixels rather than fixed regions of the image). In the above, the green regions indicate what Lime thinks the underlying model considers important for the final prediction, whereas the red regions denote parts of the image which decrease this confidence. Interestingly, we notice that the model doesn't pay much attention to the head in this prediction (possibly since it blends in with the rest of the body), but it does think that the curve of the lower body, the tail and the shape of the ear are unique identifiers of dachshunds (which is definitely believable!). Let's now look at a negative prediction. As before, the purpose of the colours is reversed in this case since we're trying to predict the negative class.
explain_lime(x_test[1].astype('double'), num_samples=3000, num_features=20)
Here we notice something interesting. The model seems to completely ignore the eyes and the nose of the dog. Perhaps this is one reason why these particular features were not very visible in our class model visualisations which seemed to focus more on the texture of the fur and the curves of the body of the dog. On the other hand, we see an emphasis put on the shape of the lower part of the head, which is definitely distinct for dachshunds versus other breeds, along with the red region under the dog. It's possible that the model is looking at the legs or the paws here, but we cannot immediately conclude this based on the crude explanation above.
Now that we are satisfied with our model, it's time to deploy it to Heroku. Alas, even with all the simplifications we still end up going over the allotted RAM at inference time. In order to get around this, we quantise the model with tf.lite to use a smaller precision. A typical default target is to convert the model to use uint8, but this is only optimised to run on ARM architecture. For Heroku we instead quantise to 16-bit floats. The final inference model ends up being only 22 MB in size, whereas the full model exported with tf.keras.models.save_model is well over 100 MB. As a comparison, the ResNet model takes about 42 MB even when quantised.
import pathlib
curdir = pathlib.Path('.')
converter = tf.lite.TFLiteConverter.from_keras_model(model_effnet)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
tflite_model_quant_file = curdir/'-'
tflite_model_quant_file.write_bytes(tflite_quant_model)
In order to avoid using Heroku's finicky filesystem, we store the user-uploaded file only in memory and pass it as a binary buffer to the HTML template. The complete inference process with the quantised model looks something like the following:
from io import BytesIO
import base64 as b64
# When a new dyno is launched it loads the quantised model and gets the default signature with
interpreter_quant = tf.lite.Interpreter(model_path='-')
predict = interpreter_quant.get_signature_runner()
# We provide Flask with a function which captures POST requests to the website
# and loads the uploaded file to memory with
in_memory_data = BytesIO()
file.seek(0)
file.save(in_memory_data)
file.close()
# Since we want to display the picture to the user along with the result,
# we pass it as base 64 encoded string to the Jinja template
img = Image.open(in_memory_data).resize((img_height, img_width))
buffer = BytesIO()
img.save(buffer, 'JPEG')
encoded_img_data = b64.b64encode(buffer.getvalue())
buffer.close()
# To perform inference we just do the following. Notice we still need to pass
# 32-bit floats, which then get truncated to 16 bits
tf_img = np.expand_dims(img, axis=0).astype(np.float32)
img.close()
prediction_prob = np.squeeze(predict(input_2=tf_img)['dense']).item()
# Finally we pass the prediction and the decoded image data back to the template as a decoded string.
# To decode the data we use:
encoded_img_data.decode('utf-8')